{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "df8313a0",
   "metadata": {},
   "source": [
    "# Chapter 10: Train a Transformer to Translate English to French\n",
    "\n",
    "This chapter covers\n",
    "\n",
    "* Tokenizing English and French phrases to subwords\n",
    "* Understanding word embedding and positional encoding\n",
    "* Training a Transformer from scratch to translate English to French\n",
    "* Using the trained Transformer to translate an English phrase into French\n",
    "\n",
    "In the last chapter, we built a Transformer from scratch that can translate between any two languages, based on the paper *Attention Is All You Need*. Specifically, we implemented the self-attention mechanism, using query, key, and value vectors to calculate scaled dot-product attention (SDPA).\n",
    "\n",
    "To deepen your understanding of self-attention and Transformers, we'll use English-to-French translation as the case study in this chapter. By training a model to convert English sentences into French, you'll see how the Transformer's architecture and its attention mechanism work in practice.\n",
    "\n",
    "Suppose you have collected more than 47,000 English-to-French translation pairs. Your objective is to train the encoder-decoder Transformer from the last chapter on this dataset. This chapter walks you through every phase of the project. You'll first use subword tokenization to break English and French phrases into tokens. You'll then build English and French vocabularies that contain all unique tokens in each language. The vocabularies let you represent English and French phrases as sequences of indexes. After that, you'll use word embedding to transform these indexes (essentially one-hot vectors) into compact vector representations. You'll add positional encodings to the word embeddings to form input embeddings. Positional encodings tell the Transformer the order of the tokens in the sequence.\n",
    "\n",
    "Finally, you'll train the encoder-decoder Transformer from Chapter 9 to translate English to French by using the collection of English-to-French translations. After training, you'll translate common English phrases with the trained Transformer. Specifically, you'll use the encoder to capture the meaning of the English phrase. You'll then use the decoder to generate the French translation autoregressively, starting with the beginning token \"BOS\". At each time step, the decoder generates the most likely next token based on the previously generated tokens and the encoder's output, until the predicted token is \"EOS\", which signals the end of the sentence. The trained model translates common English phrases accurately, much as Google Translate would, although its output is lowercase and unaccented, as the tokenized examples later in the chapter show."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ee4d965",
   "metadata": {},
   "source": [
    "# 1. Subword tokenization"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d30d2dc3",
   "metadata": {},
   "source": [
    "## 1.1. Tokenize English and French phrases\n",
    "First, go to https://gattonweb.uky.edu/faculty/lium/gai/en2fr.zip to download the zip file that contains the more than 47,000 English-to-French translations that I collected from various sources. Unzip the file and place en2fr.csv in the files/ folder in your working directory (the code below reads files/en2fr.csv). We'll load the data and take a look as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aafd4b37",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "there are 47173 examples in the training data\n",
      "How are you?\n",
      "Comment êtes-vous?\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "df=pd.read_csv(\"files/en2fr.csv\")\n",
    "num_examples=len(df)\n",
    "print(f\"there are {num_examples} examples in the training data\")\n",
    "print(df.iloc[30856][\"en\"])\n",
    "print(df.iloc[30856][\"fr\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ad38c15f",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a06a33bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['i</w>', 'don</w>', \"'t</w>\", 'speak</w>', 'fr', 'ench</w>', '.</w>']\n",
      "['je</w>', 'ne</w>', 'parle</w>', 'pas</w>', 'franc', 'ais</w>', '.</w>']\n",
      "['how</w>', 'are</w>', 'you</w>', '?</w>']\n",
      "['comment</w>', 'et', 'es-vous</w>', '?</w>']\n"
     ]
    }
   ],
   "source": [
    "from transformers import XLMTokenizer\n",
    "\n",
    "tokenizer = XLMTokenizer.from_pretrained(\"xlm-clm-enfr-1024\")\n",
    "\n",
    "tokenized_en=tokenizer.tokenize(\"I don't speak French.\")\n",
    "print(tokenized_en)\n",
    "tokenized_fr=tokenizer.tokenize(\"Je ne parle pas français.\")\n",
    "print(tokenized_fr)\n",
    "print(tokenizer.tokenize(\"How are you?\"))\n",
    "print(tokenizer.tokenize(\"Comment êtes-vous?\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7affb028",
   "metadata": {},
   "outputs": [],
   "source": [
    "# build the English vocabulary\n",
    "from collections import Counter\n",
    "\n",
    "en=df[\"en\"].tolist()\n",
    "# tokenize each phrase and add the start and end markers\n",
    "en_tokens=[[\"BOS\"]+tokenizer.tokenize(x)+[\"EOS\"] for x in en]\n",
    "# reserve index 0 for padding and index 1 for unknown tokens\n",
    "PAD=0\n",
    "UNK=1\n",
    "# count how often each token appears\n",
    "word_count=Counter()\n",
    "for sentence in en_tokens:\n",
    "    for word in sentence:\n",
    "        word_count[word]+=1\n",
    "# keep at most the 50,000 most frequent tokens\n",
    "frequency=word_count.most_common(50000)\n",
    "total_en_words=len(frequency)+2\n",
    "# map tokens to indexes, starting at 2\n",
    "en_word_dict={w[0]:idx+2 for idx,w in enumerate(frequency)}\n",
    "en_word_dict[\"PAD\"]=PAD\n",
    "en_word_dict[\"UNK\"]=UNK\n",
    "# reverse dictionary that maps indexes back to tokens\n",
    "en_idx_dict={v:k for k,v in en_word_dict.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0d0abaf5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[15, 100, 38, 377, 476, 574, 5]\n"
     ]
    }
   ],
   "source": [
    "enidx=[en_word_dict.get(i,UNK) for i in tokenized_en]   \n",
    "print(enidx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "02572e98",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['i</w>', 'don</w>', \"'t</w>\", 'speak</w>', 'fr', 'ench</w>', '.</w>']\n",
      "i don't speak french. \n"
     ]
    }
   ],
   "source": [
    "entokens=[en_idx_dict.get(i,\"UNK\") for i in enidx]   \n",
    "print(entokens)\n",
    "en_phrase=\"\".join(entokens)\n",
    "en_phrase=en_phrase.replace(\"</w>\",\" \")\n",
    "for x in '''?:;.,'(\"-!&)%''':\n",
    "    en_phrase=en_phrase.replace(f\" {x}\",f\"{x}\")   \n",
    "print(en_phrase)"
   ]
  },
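  {
   "cell_type": "markdown",
   "id": "e10a0001",
   "metadata": {},
   "source": [
    "The vocabulary-building and decoding steps above can be tried on a tiny example. The following is a minimal sketch, not the chapter's code: `toy_tokens`, `encode`, and `decode` are hypothetical names, the two phrases are made up, and only a few punctuation marks are handled.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "# hypothetical toy corpus of already-tokenized phrases (not the chapter's data)\n",
    "toy_tokens = [\n",
    "    ['BOS', 'how</w>', 'are</w>', 'you</w>', '?</w>', 'EOS'],\n",
    "    ['BOS', 'you</w>', 'are</w>', 'fr', 'ench</w>', '.</w>', 'EOS'],\n",
    "]\n",
    "PAD, UNK = 0, 1\n",
    "\n",
    "# count tokens, then give the most frequent ones the smallest indexes\n",
    "counts = Counter(tok for sent in toy_tokens for tok in sent)\n",
    "word_dict = {'PAD': PAD, 'UNK': UNK}\n",
    "for idx, (tok, _) in enumerate(counts.most_common()):\n",
    "    word_dict[tok] = idx + 2\n",
    "idx_dict = {v: k for k, v in word_dict.items()}\n",
    "\n",
    "def encode(tokens):\n",
    "    # tokens missing from the vocabulary map to UNK\n",
    "    return [word_dict.get(t, UNK) for t in tokens]\n",
    "\n",
    "def decode(indexes):\n",
    "    # join subwords, turn each </w> marker into a space,\n",
    "    # then remove the space in front of punctuation\n",
    "    text = ''.join(idx_dict.get(i, 'UNK') for i in indexes)\n",
    "    text = text.replace('</w>', ' ')\n",
    "    for mark in '?.,!':\n",
    "        text = text.replace(f' {mark}', mark)\n",
    "    return text\n",
    "\n",
    "ids = encode(['you</w>', 'are</w>', 'fr', 'ench</w>', '.</w>'])\n",
    "print(ids)\n",
    "print(decode(ids))\n",
    "print(encode(['bonjour</w>']))\n",
    "```\n",
    "\n",
    "Because `bonjour</w>` is not in the toy vocabulary, `encode` maps it to the `UNK` index 1, just as `en_word_dict.get(i,UNK)` does in the chapter."
   ]
  },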
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c94fa3cd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[157, 17, 22, 26]\n",
      "['how</w>', 'are</w>', 'you</w>', '?</w>']\n",
      "how are you? \n"
     ]
    }
   ],
   "source": [
    "# exercise 10.1\n",
    "tokens=['how</w>', 'are</w>', 'you</w>', '?</w>']\n",
    "indexes=[en_word_dict.get(i,UNK) for i in tokens]   \n",
    "print(indexes)\n",
    "tokens=[en_idx_dict.get(i,\"UNK\") for i in indexes]   \n",
    "print(tokens)\n",
    "phrase=\"\".join(tokens)\n",
    "phrase=phrase.replace(\"</w>\",\" \")\n",
    "for x in '''?:;.,'(\"-!&)%''':\n",
    "    phrase=phrase.replace(f\" {x}\",f\"{x}\")   \n",
    "print(phrase)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1f2bd162",
   "metadata": {},
   "outputs": [],
   "source": [
    "# do the same for French phrases\n",
    "fr=df[\"fr\"].tolist()       \n",
    "fr_tokens=[[\"BOS\"]+tokenizer.tokenize(x)+[\"EOS\"] for x in fr] \n",
    "word_count=Counter()\n",
    "for sentence in fr_tokens:\n",
    "    for word in sentence:\n",
    "        word_count[word]+=1\n",
    "frequency=word_count.most_common(50000)        \n",
    "total_fr_words=len(frequency)+2\n",
    "fr_word_dict={w[0]:idx+2 for idx,w in enumerate(frequency)}\n",
    "fr_word_dict[\"PAD\"]=PAD\n",
    "fr_word_dict[\"UNK\"]=UNK\n",
    "fr_idx_dict={v:k for k,v in fr_word_dict.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e4e843fe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[28, 40, 231, 32, 726, 370, 4]\n"
     ]
    }
   ],
   "source": [
    "fridx=[fr_word_dict.get(i,UNK) for i in tokenized_fr]   \n",
    "print(fridx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e8a78dd5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['je</w>', 'ne</w>', 'parle</w>', 'pas</w>', 'franc', 'ais</w>', '.</w>']\n",
      "je ne parle pas francais. \n"
     ]
    }
   ],
   "source": [
    "frtokens=[fr_idx_dict.get(i,\"UNK\") for i in fridx]   \n",
    "print(frtokens)\n",
    "fr_phrase=\"\".join(frtokens)\n",
    "fr_phrase=fr_phrase.replace(\"</w>\",\" \")\n",
    "for x in '''?:;.,'(\"-!&)%''':\n",
    "    fr_phrase=fr_phrase.replace(f\" {x}\",f\"{x}\")  \n",
    "print(fr_phrase)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9ec5627f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[452, 61, 742, 30]\n",
      "['comment</w>', 'et', 'es-vous</w>', '?</w>']\n",
      "comment etes-vous? \n"
     ]
    }
   ],
   "source": [
    "# exercise 10.2\n",
    "tokens=['comment</w>', 'et', 'es-vous</w>', '?</w>']\n",
    "indexes=[fr_word_dict.get(i,UNK) for i in tokens]   \n",
    "print(indexes)\n",
    "tokens=[fr_idx_dict.get(i,\"UNK\") for i in indexes]   \n",
    "print(tokens)\n",
    "phrase=\"\".join(tokens)\n",
    "phrase=phrase.replace(\"</w>\",\" \")\n",
    "for x in '''?:;.,'(\"-!&)%''':\n",
    "    phrase=phrase.replace(f\" {x}\",f\"{x}\")   \n",
    "print(phrase)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f77812fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "with open(\"files/dict.p\",\"wb\") as fb:\n",
    "    pickle.dump((en_word_dict,en_idx_dict,\n",
    "                 fr_word_dict,fr_idx_dict),fb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f214dff",
   "metadata": {},
   "source": [
    "## 1.2. Sequence Padding and Batch Creation\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "20480c61",
   "metadata": {},
   "outputs": [],
   "source": [
    "# convert every token to its index (UNK if not in the vocabulary)\n",
    "out_en_ids=[[en_word_dict.get(w,UNK) for w in s] for s in en_tokens]\n",
    "out_fr_ids=[[fr_word_dict.get(w,UNK) for w in s] for s in fr_tokens]\n",
    "# sort the pairs by English length so that phrases in a batch\n",
    "# have similar lengths and need little padding\n",
    "sorted_ids=sorted(range(len(out_en_ids)),\n",
    "                  key=lambda x:len(out_en_ids[x]))\n",
    "out_en_ids=[out_en_ids[x] for x in sorted_ids]\n",
    "out_fr_ids=[out_fr_ids[x] for x in sorted_ids]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "91845c48",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "batch_size=128\n",
    "# starting index of each batch; shuffle the order of the batches\n",
    "idx_list=np.arange(0,len(en_tokens),batch_size)\n",
    "np.random.shuffle(idx_list)\n",
    "\n",
    "# each batch holds up to batch_size consecutive (length-sorted) examples\n",
    "batch_indexs=[]\n",
    "for idx in idx_list:\n",
    "    batch_indexs.append(np.arange(idx,min(len(en_tokens),\n",
    "                                          idx+batch_size)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "4bec238e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def seq_padding(X, padding=0):\n",
    "    # pad every sequence in X to the length of the longest one\n",
    "    L = [len(x) for x in X]\n",
    "    ML = max(L)\n",
    "    padded_seq = np.array([np.concatenate([x, [padding] * (ML - len(x))])\n",
    "        if len(x) < ML else x for x in X])\n",
    "    return padded_seq"
   ]
  },
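  {
   "cell_type": "markdown",
   "id": "e10a0002",
   "metadata": {},
   "source": [
    "To see what `seq_padding()` produces, here is a minimal pure-Python sketch of the same idea. The names `pad_batch` and `pad_mask` and the index values are assumptions made for illustration; the chapter's function returns a NumPy array instead of nested lists.\n",
    "\n",
    "```python\n",
    "def pad_batch(seqs, padding=0):\n",
    "    # right-pad every sequence to the length of the longest one\n",
    "    max_len = max(len(s) for s in seqs)\n",
    "    return [s + [padding] * (max_len - len(s)) for s in seqs]\n",
    "\n",
    "def pad_mask(padded, padding=0):\n",
    "    # True marks a real token, False marks a padded position\n",
    "    return [[tok != padding for tok in row] for row in padded]\n",
    "\n",
    "batch = [[2, 15, 7, 3], [2, 9, 3], [2, 3]]   # hypothetical index sequences\n",
    "padded = pad_batch(batch)\n",
    "print(padded)\n",
    "print(pad_mask(padded))\n",
    "```\n",
    "\n",
    "The mask plays the role of `src_mask` in the `Batch` class from ch09util.py: positions that hold the padding index 0 are marked False so that attention can ignore them."
   ]
  },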
  {
   "cell_type": "markdown",
   "id": "73421610",
   "metadata": {},
   "source": [
    "The `Batch` class below and the two mask functions that follow it are defined in the local module ch09util.py:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1f9ac0b",
   "metadata": {},
   "source": [
    "```python\n",
    "import torch\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# define the Batch class\n",
    "class Batch:\n",
    "    def __init__(self, src, trg=None, pad=0):\n",
    "        src = torch.from_numpy(src).to(DEVICE).long()\n",
    "        self.src = src\n",
    "        self.src_mask = (src != pad).unsqueeze(-2)\n",
    "        if trg is not None:\n",
    "            trg = torch.from_numpy(trg).to(DEVICE).long()\n",
    "            self.trg = trg[:, :-1]\n",
    "            self.trg_y = trg[:, 1:]\n",
    "            self.trg_mask = make_std_mask(self.trg, pad)\n",
    "            self.ntokens = (self.trg_y != pad).data.sum()\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b38389fa",
   "metadata": {},
   "source": [
    "```python\n",
    "import numpy as np\n",
    "def subsequent_mask(size):\n",
    "    attn_shape = (1, size, size)\n",
    "    subsequent_mask = np.triu(np.ones(attn_shape),\n",
    "                              k=1).astype('uint8')\n",
    "    output = torch.from_numpy(subsequent_mask) == 0\n",
    "    return output\n",
    "\n",
    "def make_std_mask(tgt, pad):\n",
    "    tgt_mask = (tgt != pad).unsqueeze(-2)\n",
    "    output = tgt_mask & subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)\n",
    "    return output \n",
    "```"
   ]
  },
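  {
   "cell_type": "markdown",
   "id": "e10a0003",
   "metadata": {},
   "source": [
    "The `Batch` class shifts the target sequence by one position and masks both padding and future tokens. The sketch below reproduces that logic for a single sequence with plain Python lists; `causal_mask`, `std_mask`, and the index values (BOS=2, EOS=3, PAD=0) are hypothetical and chosen only for illustration.\n",
    "\n",
    "```python\n",
    "def causal_mask(size):\n",
    "    # row i may attend to columns 0..i only\n",
    "    return [[col <= row for col in range(size)] for row in range(size)]\n",
    "\n",
    "def std_mask(tgt_row, pad=0):\n",
    "    # combine the padding mask with the causal mask for one sequence\n",
    "    size = len(tgt_row)\n",
    "    causal = causal_mask(size)\n",
    "    return [[causal[r][c] and tgt_row[c] != pad for c in range(size)]\n",
    "            for r in range(size)]\n",
    "\n",
    "trg = [2, 8, 5, 3, 0]        # hypothetical: BOS=2, EOS=3, PAD=0\n",
    "decoder_input = trg[:-1]     # what the decoder sees\n",
    "decoder_target = trg[1:]     # what it must predict, shifted by one\n",
    "ntokens = sum(t != 0 for t in decoder_target)\n",
    "for row in std_mask(decoder_input):\n",
    "    print([int(v) for v in row])\n",
    "print(decoder_input, decoder_target, ntokens)\n",
    "```\n",
    "\n",
    "Row i of the mask allows attention to positions 0 through i only, and `ntokens` counts the non-padding targets, which the training loop later uses to average the loss."
   ]
  },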
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "6e37251c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.ch09util import Batch\n",
    "\n",
    "class BatchLoader():\n",
    "    def __init__(self):\n",
    "        self.idx=0\n",
    "    def __iter__(self):\n",
    "        return self\n",
    "    def __next__(self):\n",
    "        self.idx += 1\n",
    "        if self.idx<=len(batch_indexs):\n",
    "            b=batch_indexs[self.idx-1]\n",
    "            batch_en=[out_en_ids[x] for x in b]\n",
    "            batch_fr=[out_fr_ids[x] for x in b]\n",
    "            batch_en=seq_padding(batch_en)\n",
    "            batch_fr=seq_padding(batch_fr)\n",
    "            return Batch(batch_en,batch_fr)\n",
    "        raise StopIteration"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4450f4fe",
   "metadata": {},
   "source": [
    "# 2. Word embedding and positional encoding\n",
    "## 2.1. Word Embedding\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "05d352a4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "there are 11055 distinct English tokens\n",
      "there are 11239 distinct French tokens\n"
     ]
    }
   ],
   "source": [
    "src_vocab = len(en_word_dict)\n",
    "tgt_vocab = len(fr_word_dict)\n",
    "print(f\"there are {src_vocab} distinct English tokens\")\n",
    "print(f\"there are {tgt_vocab} distinct French tokens\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "139037be",
   "metadata": {},
   "source": [
    "```python\n",
    "import math\n",
    "from torch import nn\n",
    "\n",
    "class Embeddings(nn.Module):\n",
    "    def __init__(self, d_model, vocab):\n",
    "        super().__init__()\n",
    "        self.lut = nn.Embedding(vocab, d_model)\n",
    "        self.d_model = d_model\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.lut(x) * math.sqrt(self.d_model)\n",
    "        return out\n",
    "```"
   ]
  },
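  {
   "cell_type": "markdown",
   "id": "e10a0004",
   "metadata": {},
   "source": [
    "An embedding layer is a lookup table: multiplying a one-hot vector by the embedding matrix selects one row of it. The sketch below checks this with plain Python lists and then applies the sqrt(d_model) scaling used in `Embeddings.forward`. The 3-token vocabulary, the weights, and the helper names are made up for illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "d_model = 4\n",
    "# hypothetical 3-token vocabulary with hand-picked weights\n",
    "embedding = [\n",
    "    [0.1, 0.2, 0.3, 0.4],   # index 0\n",
    "    [0.5, 0.6, 0.7, 0.8],   # index 1\n",
    "    [0.9, 1.0, 1.1, 1.2],   # index 2\n",
    "]\n",
    "\n",
    "def one_hot(index, size):\n",
    "    return [1.0 if i == index else 0.0 for i in range(size)]\n",
    "\n",
    "def matvec(vec, matrix):\n",
    "    # vector-matrix product: vec (size V) times matrix (V x d)\n",
    "    return [sum(vec[i] * matrix[i][j] for i in range(len(vec)))\n",
    "            for j in range(len(matrix[0]))]\n",
    "\n",
    "def embed(index):\n",
    "    # look up the row and scale it by sqrt(d_model)\n",
    "    return [w * math.sqrt(d_model) for w in embedding[index]]\n",
    "\n",
    "via_one_hot = matvec(one_hot(2, 3), embedding)\n",
    "print(via_one_hot == embedding[2])\n",
    "print(embed(2))\n",
    "```\n",
    "\n",
    "In practice the lookup is used instead of the matrix product because it avoids building one-hot vectors as long as the vocabulary."
   ]
  },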
  {
   "cell_type": "markdown",
   "id": "a5f17ecb",
   "metadata": {},
   "source": [
    "## 2.2. Positional Encoding\n",
    "To model the order of elements in the input and output sequences, we'll first create positional encodings of the sequences as follows:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d886689",
   "metadata": {},
   "source": [
    "```python\n",
    "class PositionalEncoding(nn.Module):\n",
    "    def __init__(self, d_model, dropout, max_len=5000):\n",
    "        super().__init__()\n",
    "        self.dropout = nn.Dropout(p=dropout)\n",
    "        pe = torch.zeros(max_len, d_model, device=DEVICE)\n",
    "        position = torch.arange(0., max_len, \n",
    "                                device=DEVICE).unsqueeze(1)\n",
    "        div_term = torch.exp(torch.arange(\n",
    "            0., d_model, 2, device=DEVICE)\n",
    "            * -(math.log(10000.0) / d_model))\n",
    "        pe_pos = torch.mul(position, div_term)\n",
    "        pe[:, 0::2] = torch.sin(pe_pos)\n",
    "        pe[:, 1::2] = torch.cos(pe_pos)\n",
    "        pe = pe.unsqueeze(0)\n",
    "        self.register_buffer('pe', pe)  \n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x + self.pe[:, :x.size(1)].requires_grad_(False)\n",
    "        out = self.dropout(x)\n",
    "        return out\n",
    "```"
   ]
  },
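  {
   "cell_type": "markdown",
   "id": "e10a0005",
   "metadata": {},
   "source": [
    "The sine and cosine pattern computed by `PositionalEncoding` can be reproduced without PyTorch. This is a minimal sketch that leaves dropout out; `positional_encoding` is a hypothetical helper, not part of ch09util.py.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def positional_encoding(max_len, d_model):\n",
    "    # for every even dimension i:\n",
    "    #   pe[pos][i]   = sin(pos / 10000^(i/d_model))\n",
    "    #   pe[pos][i+1] = cos(pos / 10000^(i/d_model))\n",
    "    pe = [[0.0] * d_model for _ in range(max_len)]\n",
    "    for pos in range(max_len):\n",
    "        for i in range(0, d_model, 2):\n",
    "            angle = pos * math.exp(-math.log(10000.0) * i / d_model)\n",
    "            pe[pos][i] = math.sin(angle)\n",
    "            if i + 1 < d_model:\n",
    "                pe[pos][i + 1] = math.cos(angle)\n",
    "    return pe\n",
    "\n",
    "pe = positional_encoding(max_len=8, d_model=256)\n",
    "print(len(pe), len(pe[0]))\n",
    "print(pe[0][:4])\n",
    "print([round(v, 4) for v in pe[1][:4]])\n",
    "```\n",
    "\n",
    "Position 0 alternates between sin(0)=0 and cos(0)=1, and the first two values at position 1 are sin(1) and cos(1)."
   ]
  },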
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "9dffa6e3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the shape of positional encoding is torch.Size([1, 8, 256])\n",
      "tensor([[[ 0.0000e+00,  1.1111e+00,  0.0000e+00,  ...,  1.1111e+00,\n",
      "           0.0000e+00,  1.1111e+00],\n",
      "         [ 9.3497e-01,  6.0034e-01,  8.9107e-01,  ...,  1.1111e+00,\n",
      "           1.1940e-04,  1.1111e+00],\n",
      "         [ 1.0103e+00, -4.6239e-01,  1.0646e+00,  ...,  1.1111e+00,\n",
      "           2.3880e-04,  1.1111e+00],\n",
      "         ...,\n",
      "         [-1.0655e+00,  3.1518e-01, -1.1091e+00,  ...,  1.1111e+00,\n",
      "           5.9700e-04,  1.1111e+00],\n",
      "         [-3.1046e-01,  1.0669e+00, -7.1559e-01,  ...,  1.1111e+00,\n",
      "           7.1640e-04,  1.1111e+00],\n",
      "         [ 7.2999e-01,  0.0000e+00,  2.5419e-01,  ...,  1.1111e+00,\n",
      "           8.3581e-04,  1.1111e+00]]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "from utils.ch09util import PositionalEncoding\n",
    "import torch\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "pe = PositionalEncoding(256, 0.1)\n",
    "x = torch.zeros(1, 8, 256).to(DEVICE)\n",
    "# pe is in training mode here, so dropout (p=0.1) zeroes some\n",
    "# entries and scales the rest by 1/0.9, hence the 1.1111 values\n",
    "y = pe(x)\n",
    "print(f\"the shape of positional encoding is {y.shape}\")\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0df1b8ba",
   "metadata": {},
   "source": [
    "# 3. Train the Transformer for English-to-French translation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4768df6",
   "metadata": {},
   "source": [
    "## 3.1 Loss Function and the Optimizer\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "22e51429",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.ch09util import create_model\n",
    "\n",
    "model = create_model(src_vocab, tgt_vocab, N=6,\n",
    "    d_model=256, d_ff=1024, h=8, dropout=0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf29d76a",
   "metadata": {},
   "source": [
    "We define the following two classes, `LabelSmoothing` and `NoamOpt`, in the local module:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77c8742a",
   "metadata": {},
   "source": [
    "```python\n",
    "class LabelSmoothing(nn.Module):\n",
    "    def __init__(self, size, padding_idx, smoothing=0.1):\n",
    "        super().__init__()\n",
    "        self.criterion = nn.KLDivLoss(reduction='sum')  \n",
    "        self.padding_idx = padding_idx\n",
    "        self.confidence = 1.0 - smoothing\n",
    "        self.smoothing = smoothing\n",
    "        self.size = size\n",
    "        self.true_dist = None\n",
    "\n",
    "    def forward(self, x, target):\n",
    "        assert x.size(1) == self.size\n",
    "        true_dist = x.data.clone()\n",
    "        true_dist.fill_(self.smoothing / (self.size - 2))\n",
    "        true_dist.scatter_(1, \n",
    "               target.data.unsqueeze(1), self.confidence)\n",
    "        true_dist[:, self.padding_idx] = 0\n",
    "        mask = torch.nonzero(target.data == self.padding_idx)\n",
    "        if mask.dim() > 0:\n",
    "            true_dist.index_fill_(0, mask.squeeze(), 0.0)\n",
    "        self.true_dist = true_dist\n",
    "        output = self.criterion(x, true_dist.clone().detach())\n",
    "        return output\n",
    "```"
   ]
  },
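  {
   "cell_type": "markdown",
   "id": "e10a0006",
   "metadata": {},
   "source": [
    "To make the target distribution built by `LabelSmoothing.forward` concrete, the sketch below constructs it for a single target with plain Python. The function name `smoothed_target` and the 6-token vocabulary are assumptions for illustration; the class above then passes this distribution to `nn.KLDivLoss`, which expects the model output as log-probabilities.\n",
    "\n",
    "```python\n",
    "def smoothed_target(target, size, padding_idx=0, smoothing=0.1):\n",
    "    # spread the smoothing mass over every class except the target and\n",
    "    # the padding index, which is why the divisor is size - 2\n",
    "    if target == padding_idx:\n",
    "        return [0.0] * size          # padded positions contribute no loss\n",
    "    dist = [smoothing / (size - 2)] * size\n",
    "    dist[target] = 1.0 - smoothing\n",
    "    dist[padding_idx] = 0.0\n",
    "    return dist\n",
    "\n",
    "dist = smoothed_target(target=3, size=6)\n",
    "print([round(p, 3) for p in dist])\n",
    "print(round(sum(dist), 6))\n",
    "print(smoothed_target(target=0, size=6))\n",
    "```\n",
    "\n",
    "The target receives 0.9, the padding index receives 0, and the remaining four tokens share the 0.1 of smoothing mass, so the distribution sums to 1."
   ]
  },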
  {
   "cell_type": "markdown",
   "id": "61cd3b51",
   "metadata": {},
   "source": [
    "```python\n",
    "class NoamOpt:\n",
    "    def __init__(self, model_size, factor, warmup, optimizer):\n",
    "        self.optimizer = optimizer\n",
    "        self._step = 0\n",
    "        self.warmup = warmup\n",
    "        self.factor = factor\n",
    "        self.model_size = model_size\n",
    "        self._rate = 0\n",
    "\n",
    "    def step(self):\n",
    "        self._step += 1\n",
    "        rate = self.rate()\n",
    "        for p in self.optimizer.param_groups:\n",
    "            p['lr'] = rate\n",
    "        self._rate = rate\n",
    "        self.optimizer.step()\n",
    "\n",
    "    def rate(self, step=None):\n",
    "        if step is None:\n",
    "            step = self._step\n",
    "        output = self.factor * (self.model_size ** (-0.5) *\n",
    "        min(step ** (-0.5), step * self.warmup ** (-1.5)))\n",
    "        return output\n",
    "```"
   ]
  },
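  {
   "cell_type": "markdown",
   "id": "e10a0007",
   "metadata": {},
   "source": [
    "The learning-rate schedule in `NoamOpt.rate()` is easier to read as a standalone function. This sketch copies the formula and uses the settings from this chapter (model size 256, factor 1, 2,000 warm-up steps) as defaults; `noam_rate` is a hypothetical name.\n",
    "\n",
    "```python\n",
    "def noam_rate(step, model_size=256, factor=1, warmup=2000):\n",
    "    # linear warm-up for the first `warmup` steps, then decay ~ step^-0.5\n",
    "    # (step must be at least 1, as it is after the first optimizer step)\n",
    "    return factor * (model_size ** -0.5 *\n",
    "                     min(step ** -0.5, step * warmup ** -1.5))\n",
    "\n",
    "for step in (1, 500, 2000, 8000, 32000):\n",
    "    print(step, f'{noam_rate(step):.6f}')\n",
    "```\n",
    "\n",
    "The rate grows linearly until step 2,000, where it peaks, and then decays in proportion to the inverse square root of the step number."
   ]
  },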
  {
   "cell_type": "markdown",
   "id": "a073dab3",
   "metadata": {},
   "source": [
    "We create the optimizer for training as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "29a4ab8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.ch09util import NoamOpt\n",
    "\n",
    "optimizer = NoamOpt(256, 1, 2000, torch.optim.Adam(\n",
    "    model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b2434c4",
   "metadata": {},
   "source": [
    "To create the loss function for training, we first define the following class in the local module:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4602ce37",
   "metadata": {},
   "source": [
    "```python\n",
    "class SimpleLossCompute:\n",
    "    def __init__(self, generator, criterion, opt=None):\n",
    "        self.generator = generator\n",
    "        self.criterion = criterion\n",
    "        self.opt = opt\n",
    "\n",
    "    def __call__(self, x, y, norm):\n",
    "        x = self.generator(x)\n",
    "        loss = self.criterion(x.contiguous().view(-1, x.size(-1)),\n",
    "                              y.contiguous().view(-1)) / norm\n",
    "        loss.backward()\n",
    "        if self.opt is not None:\n",
    "            self.opt.step()\n",
    "            self.opt.optimizer.zero_grad()\n",
    "        return loss.data.item() * norm.float()\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5163f7f",
   "metadata": {},
   "source": [
    "We then define the loss function as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "5b0f9cc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.ch09util import (LabelSmoothing,\n",
    "       SimpleLossCompute)\n",
    "\n",
    "criterion = LabelSmoothing(tgt_vocab, \n",
    "                           padding_idx=0, smoothing=0.1)\n",
    "loss_func = SimpleLossCompute(\n",
    "            model.generator, criterion, optimizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "062599ec",
   "metadata": {},
   "source": [
    "## 3.2 The training loop\n",
    "We'll train the model for 100 epochs. We'll calculate the loss and the number of tokens from each batch. After each epoch, we calculate the average loss in the epoch as the ratio between the total loss and the total number of tokens:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "0146ee51",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train for 100 epochs\n",
    "for epoch in range(100):\n",
    "    model.train()\n",
    "    tloss=0\n",
    "    tokens=0\n",
    "    for batch in BatchLoader():\n",
    "        out = model(batch.src, batch.trg, \n",
    "                    batch.src_mask, batch.trg_mask)\n",
    "        loss = loss_func(out, batch.trg_y, batch.ntokens)\n",
    "        tloss += loss\n",
    "        tokens += batch.ntokens\n",
    "    print(f\"Epoch {epoch}, average loss: {tloss/tokens}\")\n",
    "torch.save(model.state_dict(),\"files/en2fr.pth\")   "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3274397",
   "metadata": {},
   "source": [
    "Training takes a couple of hours on a GPU and may take several hours on a CPU. Once training is done, the model weights are saved as files/en2fr.pth on your computer. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f83ef71",
   "metadata": {},
   "source": [
    "# 4. Translate English to French with the Trained Model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "a4779970",
   "metadata": {},
   "outputs": [],
   "source": [
    "def translate(eng):\n",
    "    # tokenize the English sentence\n",
    "    tokenized_en=tokenizer.tokenize(eng)\n",
    "    # add beginning and end tokens\n",
    "    tokenized_en=[\"BOS\"]+tokenized_en+[\"EOS\"]\n",
    "    # convert tokens to indexes\n",
    "    enidx=[en_word_dict.get(i,UNK) for i in tokenized_en]  \n",
    "    src=torch.tensor(enidx).long().to(DEVICE).unsqueeze(0)\n",
    "    # create mask to hide padding\n",
    "    src_mask=(src!=0).unsqueeze(-2)\n",
    "    # encode the English sentence\n",
    "    memory=model.encode(src,src_mask)\n",
    "    # start translation in an autoregressive fashion\n",
    "    start_symbol=fr_word_dict[\"BOS\"]\n",
    "    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)\n",
    "    translation=[]\n",
    "    for i in range(100):\n",
    "        out = model.decode(memory,src_mask,ys,\n",
    "        subsequent_mask(ys.size(1)).type_as(src.data))\n",
    "        prob = model.generator(out[:, -1])\n",
    "        _, next_word = torch.max(prob, dim=1)\n",
    "        next_word = next_word.data[0]\n",
    "        ys = torch.cat([ys, torch.ones(1, 1).type_as(\n",
    "            src.data).fill_(next_word)], dim=1)\n",
    "        sym = fr_idx_dict[ys[0, -1].item()]\n",
    "        if sym != 'EOS':\n",
    "            translation.append(sym)\n",
    "        else:\n",
    "            break\n",
    "    # convert tokens to sentences\n",
    "    trans=\"\".join(translation)\n",
    "    trans=trans.replace(\"</w>\",\" \") \n",
    "    for x in '''?:;.,'(\"-!&)%''':\n",
    "        trans=trans.replace(f\" {x}\",f\"{x}\")    \n",
    "    print(trans)\n",
    "    return trans"
   ]
  },
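  {
   "cell_type": "markdown",
   "id": "e10a0008",
   "metadata": {},
   "source": [
    "The loop inside `translate()` is greedy decoding: at each step it keeps only the most likely next token and stops at `EOS`. The sketch below isolates that loop, replacing the trained model with a hypothetical lookup table of made-up scores (`toy_scores`), so it runs without PyTorch.\n",
    "\n",
    "```python\n",
    "# hypothetical next-token scores; a real model would compute these from\n",
    "# the encoder output and the tokens generated so far\n",
    "toy_scores = {\n",
    "    ('BOS',): {'je</w>': 0.9, 'EOS': 0.1},\n",
    "    ('BOS', 'je</w>'): {'parle</w>': 0.7, 'EOS': 0.3},\n",
    "    ('BOS', 'je</w>', 'parle</w>'): {'.</w>': 0.8, 'EOS': 0.2},\n",
    "    ('BOS', 'je</w>', 'parle</w>', '.</w>'): {'EOS': 1.0},\n",
    "}\n",
    "\n",
    "def greedy_decode(next_scores, max_len=100):\n",
    "    ys = ['BOS']                      # start with the beginning token\n",
    "    for _ in range(max_len):\n",
    "        scores = next_scores[tuple(ys)]\n",
    "        # greedy choice: keep only the single most likely token\n",
    "        next_tok = max(scores, key=scores.get)\n",
    "        if next_tok == 'EOS':         # stop at the end-of-sentence token\n",
    "            break\n",
    "        ys.append(next_tok)\n",
    "    return ys[1:]                     # drop BOS\n",
    "\n",
    "tokens = greedy_decode(toy_scores)\n",
    "print(tokens)\n",
    "print(''.join(tokens).replace('</w>', ' ').replace(' .', '.'))\n",
    "```\n",
    "\n",
    "Greedy decoding is simple and fast, but it never revisits an earlier choice; `max_len` caps the output length in case `EOS` is never predicted, like the `range(100)` limit in `translate()`."
   ]
  },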
  {
   "cell_type": "markdown",
   "id": "a3e8c99d",
   "metadata": {},
   "source": [
    "Let's try the `translate()` function on the English phrase \"Today is a beautiful day!\", like so:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "5a2af177",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "aujourd'hui est une belle journee! \n"
     ]
    }
   ],
   "source": [
    "from utils.ch09util import subsequent_mask\n",
    "\n",
    "with open(\"files/dict.p\",\"rb\") as fb:\n",
    "    en_word_dict,en_idx_dict,\\\n",
    "    fr_word_dict,fr_idx_dict=pickle.load(fb)\n",
    "trained_weights=torch.load(\"files/en2fr.pth\",\n",
    "                           map_location=DEVICE)\n",
    "model.load_state_dict(trained_weights)\n",
    "model.eval()\n",
    "eng = \"Today is a beautiful day!\"\n",
    "translated_fr = translate(eng)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "6c4fea34",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "un petit garcon en jeans grimpe un petit arbre tandis qu'un autre enfant regarde. \n"
     ]
    }
   ],
   "source": [
    "eng = \"A little boy in jeans climbs a small tree while another child looks on.\"\n",
    "translated_fr = translate(eng)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "03d060f9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "je ne parle pas francais. \n"
     ]
    }
   ],
   "source": [
    "eng = \"I don't speak French.\"\n",
    "translated_fr = translate(eng)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f342772",
   "metadata": {},
   "source": [
    "Now let's try the sentence \"I do not speak French.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "f39f8649",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "je ne parle pas francais. \n"
     ]
    }
   ],
   "source": [
    "eng = \"I do not speak French.\"\n",
    "translated_fr = translate(eng)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "279e4604",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "j'aime le ski en hiver! \n",
      "comment etes-vous? \n"
     ]
    }
   ],
   "source": [
    "# exercise 10.3\n",
    "eng = \"I love skiing in the winter!\"\n",
    "translated_fr = translate(eng)\n",
    "eng = \"How are you?\"\n",
    "translated_fr = translate(eng)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "2bac19f8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
